#
# Copyright (c) 2024 Bitdefender
# SPDX-License-Identifier: Apache-2.0
#
import sys
import re
import json
import string


class UniqDict(dict):
    """
    Dictionary that refuses to overwrite an existing key through item assignment.

    Only ``__setitem__`` is overridden: ``d[key] = value`` raises ``ValueError`` when
    ``key`` is already present, instead of silently replacing the previous value.
    """

    def __setitem__(self, key: str, value: str) -> type(None):
        # Guard clause: reject a duplicate before delegating to the plain dict store.
        if key in self:
            raise ValueError("key is already present : %s" % (key))
        return super().__setitem__(key, value)


class DecodeShemuParser(object):
    """
    Base class for the line-oriented disasmtool output parsers.

    ``self._obj`` holds the raw input (subclasses replace it with a list of lines before
    reading), ``self._crt`` is the index of the next line to read, and ``self._data``
    collects the parsed key/value pairs.
    """

    def __init__(self, obj: str):
        self._data = UniqDict()
        self._crt = 0
        self._obj = obj

    def rdline(self) -> str:
        """
        Return the line under the cursor and move the cursor forward by one.

        Returns None (without moving the cursor) once every line has been consumed.
        """
        if self._crt != len(self._obj):
            current = self._obj[self._crt]
            self._crt += 1
            return current
        return None

    def rdnline(self) -> str:
        """
        Peek at the line under the cursor without moving the cursor.

        Returns None once every line has been consumed.
        """
        return None if self._crt == len(self._obj) else self._obj[self._crt]


class ShemuResult(DecodeShemuParser):
    """Parser for the summary disasmtool prints when a shemu emulation run ends."""

    def __init__(self, obj: str):
        DecodeShemuParser.__init__(self, obj)

        self.process()

    def process(self) -> type(None):
        """
        Parses the emulation result generated by disasmtool and stores it in a dictionary as key:value pair as follows:
        {
            "Emulation terminated with status": "0x0000000a",
            "flags:": "0x80",
            "NOPs": "0",
            "NULLs": "0",
            "total instructions": "2",
            "unique instructions": "1",
            "SHEMU_FLAG-0": "SHEMU_FLAG_HEAVENS_GATE"
        }

        Empty lines are discarded. The first remaining line is the comma separated
        summary; every following line is stored verbatim (stripped) as "SHEMU_FLAG-<n>".
        """
        # Keep only the non-empty lines of the output.
        self._obj = [entry for entry in self._obj.split("\n") if entry]

        # Emulation terminated with status 0x0000000a, flags: 0x10, 0 NOPs, 0 NULLs, 10 total instructions, 10 unique instructions
        summary = self.rdline()
        value_chars = string.hexdigits + "x"

        for field in summary.split(","):
            stripped = field.strip()
            if not stripped:
                continue

            words = stripped.split(" ")
            last = words[-1]
            # The value sits at whichever end of the field looks numeric: the trailing
            # word when it is hex/decimal ("status 0x0a"), otherwise the leading word ("0 NOPs").
            if last.isnumeric() or all(ch in value_chars for ch in last):
                key, value = " ".join(words[:-1]), last
            else:
                key, value = " ".join(words[1:]), words[0]
            self._data[key] = value

        # Each remaining line names one raised SHEMU flag; rdline() yields None at the end.
        for index, flag in enumerate(iter(self.rdline, None)):
            self._data["SHEMU_FLAG-%s" % (index)] = flag.strip()


class ShemuInstrux(DecodeShemuParser):
    """Parser for the per-instruction trace disasmtool prints while emulating with shemu."""

    def __init__(self, obj: str):
        DecodeShemuParser.__init__(self, obj)

        self.process()

    def process(self) -> type(None):
        """
        Parses the emulation result generated by disasmtool and stores it in a dictionary as key:value pair as follows:
        {
            "RAX": "0x0000000000000000",
            "RCX": "0x0000000000000000",
            "RDX": "0x0000000000000000",
            "RBX": "0x0000000000000000",
            "RBP": "0x0000000000000000",
            "RSI": "0x0000000000000000",
            ...
            "R28": "0x0000000000000000",
            "R29": "0x0000000000000000",
            "R30": "0x0000000000000000",
            "R31": "0x0000000000000000",
            "RIP": "0x0000000000200000",
            "RFLAGS": "0x0000000000000202"
        }

        Other keys that may be stored:
          - "<XX>-INFO" (e.g. "IP-INFO"): the address taken from an "IP: " / "PC: " line,
            together with "InstructionBytes" and "InstructionText" from the same line;
          - "Detection-<n>": the hex value of the n-th "Detection: " line;
          - "<XX>": single-digit values from "XX: d" pairs on any other line with a colon.

        Parsing stops at the first empty line.
        """
        self._obj = self._obj.split("\n")
        detections = 0
        line = self.rdline()
        while line:
            if " = " in line:
                # Register dump: one or more "NAME = value" pairs per line.
                for token in re.findall(r"\w+\s*=\s*[0x]*[\da-f]{4,16}", line):
                    key, _, val = token.partition("=")
                    self._data[key.strip()] = val.strip()
            if "IP: " in line or "PC: " in line:
                # Instruction line: the address, instruction bytes and instruction text
                # are separated by double spaces.
                address = re.findall(r"\w\w\s*:\s*0x[\da-f]{16}", line)[0]
                key, _, val = address.partition(":")
                self._data["%s-%s" % (key.strip(), "INFO")] = val.strip()

                fields = line.split('  ')
                self._data["InstructionBytes"] = fields[1].strip()
                self._data["InstructionText"] = " ".join(fields[2:]).strip()
            elif "Detection: " in line:
                # The detection value is printed in hex, so it may contain a-f digits.
                # Accept the same digit set as the "IP: " pattern above; matching only
                # decimal digits left findall() empty and made [0] raise IndexError.
                detection = re.findall(r"\w+\s*:\s*0x[\da-f]{16}", line)[0]
                key, _, val = detection.partition(":")
                self._data["%s-%d" % (key.strip(), detections)] = val.strip()
                detections += 1
            elif ":" in line:
                # Remaining "XX: d" pairs (two-character name, one decimal digit).
                for token in re.findall(r"\w\w\s*:\s*\d{1}", line):
                    key, _, val = token.partition(":")
                    self._data[key.strip()] = val.strip()

            line = self.rdline()


class DecodeInstrux(DecodeShemuParser):
    """Parser for the textual description disasmtool prints for a decoded instruction."""

    def __init__(self, obj: str):
        DecodeShemuParser.__init__(self, obj)

        self.process()

    def process(self) -> type(None):
        """
        Parses an instruction generated by disasmtool and stores it in a dictionary as key:value pair as follows:
        {
            "InstructionBytes": "c4e2784900",
            "InstructionText": "LDTILECFG zmmword ptr [rax]",
            "RIP": "0000000000000000",
            "DSIZE": "32",
            "ASIZE": "64",
            "VLEN": "-",
            "ISA Set": "AMX-TILE",
             ...
            "Operand-0": {
                "Operand": "0",
                "Acc": "R-",
                "Type": "Memory",
                "Size": "64",
                "RawSize": "64",
                "Encoding": "M",
                "Segment": "3",
                "Base": "0"
            }
        }

        Each "Operand:" line, together with the continuation lines that follow it, becomes a
        nested dictionary stored under "Operand-<n>". Several "Decorator" entries on the same
        operand are stored as "Decorator-0", "Decorator-1", ... so they do not collide.

        Parsing stops at the first empty line.
        """
        self._obj = self._obj.split("\n")

        line = self.rdline()
        while line:
            # 0000000000000000 c4e2784900                      LDTILECFG zmmword ptr [rax]
            # 0000000000000000 62                              db 0x62 (0x80000002)
            if re.search("^[0-9A-F]{16}", line):
                tokens = [tok for tok in line.split(" ") if tok.strip()]
                self._data["InstructionBytes"] = tokens[1].strip()
                self._data["InstructionText"] = " ".join(tokens[2:]).strip()
                self._data["RIP"] = tokens[0].strip()

            # Operand: 0, Acc:  RW,  Type:   Register, Size:  1, RawSize:  1, Encoding: M, RegType:  General Purpose,
            # RegSize:  1, RegId: 22, RegCount: 1
            if "Operand:" in line:
                # An operand description wraps over several physical lines: append every
                # following line until the next "Operand:" line, an empty line or the end.
                while self.rdnline() and "Operand:" not in self.rdnline():
                    line += self.rdline()

                local = UniqDict()
                tokens = [tok for tok in line.split(",") if tok.strip()]
                dc = 0
                for token in tokens:
                    parts = token.strip().split(":")
                    key = parts[0].strip()
                    val = parts[1].strip()
                    # Number repeated decorators so UniqDict does not reject them. The key
                    # must be computed for the *current* token before this check.
                    if key == "Decorator":
                        key = "%s-%s" % (key, dc)
                        dc += 1
                    local[key] = val

                # The first token is "Operand: <n>", giving the "Operand-<n>" key.
                head = tokens[0].strip().split(":")
                self._data["%s-%s" % (head[0].strip(), head[1].strip())] = local

            # EVEX Tuple Type: Tuple 1 scalar, 8 bit
            # EVEX Tuple Type: Full
            elif "EVEX" in line:
                # The value may contain commas, so only split on the colon.
                tokens = line.split(":")
                self._data[tokens[0].strip()] = tokens[1].strip()

            elif ": " in line:
                # Generic "Key: value, Key: value, ..." line.
                for token in (tok for tok in line.split(",") if tok.strip()):
                    parts = token.strip().split(":")
                    self._data[parts[0].strip()] = parts[1].strip()

            line = self.rdline()
